Bypass the softmax pytorch kernel that upscales to fp32#19865
Bypass the softmax pytorch kernel that upscales to fp32#19865telgamal-1 wants to merge 1 commit into
Conversation
Summary: Bypasses the PyTorch softmax kernel in `static_attention` that upscales activations to fp32, keeping the softmax computation in fp16. Also updates `norm.py` to handle the fp16 softmax output. Differential Revision: D106729898
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19865
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 Awaiting Approval, 2 Unrelated Failures, 6 Unclassified FailuresAs of commit 4260f2a with merge base 42581f1 ( UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
@telgamal-1 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D106729898. |
This PR needs a
|
Summary: Bypasses the PyTorch softmax kernel in
static_attentionthat upscales activations to fp32, keeping the softmax computation in fp16. Also updatesnorm.pyto handle the fp16 softmax output.Differential Revision: D106729898